{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as trans\n",
    "\n",
    "trainSet = dsets.MNIST(root='./data', train = True, transform=trans.ToTensor(), download=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of images 60000\n",
      "Type <class 'torch.Tensor'>\n",
      "Size of each image torch.Size([1, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "print('Number of images {}'.format(len(trainSet)))\n",
    "print('Type {}'.format(type(trainSet[0][0])))\n",
    "print('Size of each image {}'.format(trainSet[0][0].size()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 28, 28])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainSet[0][0].size() # each image is of size 28 X 28 and is greyscale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class MultiLogisticModel(nn.Module):\n",
    "    def __init__(self, in_dim, out_dim):\n",
    "        super(MultiLogisticModel, self).__init__()\n",
    "        self.linear = nn.Linear(in_dim, out_dim)\n",
    "        \n",
    "    def forward(self,x):        \n",
    "        out =self.linear(x)\n",
    "        return out "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length 60000 Size torch.Size([1, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "print('Length {} Size {}'.format(len(trainSet), trainSet[0][0].size()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "in_dim = 28*28 #input dimensions\n",
    "out_dim = 10 #output dimensions\n",
    "model = MultiLogisticModel(in_dim, out_dim) #instantiate model\n",
    "criterion = nn.CrossEntropyLoss()  #instantiate the loss class\n",
    "optimiser = torch.optim.SGD(model.parameters(), lr=.1) #instantiate the optimiser class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batchSize = 200\n",
    "epochs = 1\n",
    "trainloader = torch.utils.data.DataLoader(dataset=trainSet, batch_size=batchSize, shuffle = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "196.61259549856186\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(epochs): \n",
    "\n",
    "    runningLoss = 0.0\n",
    "    for i, (images,labels) in enumerate(trainloader):\n",
    "        images = images.view(-1, 28 * 28)       \n",
    "        optimiser.zero_grad()\n",
    "        outputs = model(images)        \n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimiser.step()                      \n",
    "        runningLoss += loss.item()\n",
    "      \n",
    "    print(runningLoss)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([200, 10])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted = model.forward(images) \n",
    "predicted.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predictions tensor([-3.7056,  0.0625,  1.0488, -0.0671, -0.3245,  1.4037,  4.0016,\n",
      "        -3.3711,  2.2218, -1.1201])\n",
      "labels 6\n"
     ]
    }
   ],
   "source": [
    "print('predictions {}'.format(predicted[0]))\n",
    "print('labels {}'.format(labels[0]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 6,  8,  1])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels[0:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def successRate(predicted, labels):\n",
    "    predict = [np.argmax(p.detach().numpy()) for p in predicted]\n",
    "    actual =  [labels[i].item() for i in range(len(predicted))]\n",
    "    correct = [i for i, j in zip(predict, actual) if i == j]\n",
    "    return(len(correct)/(len(predict)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.92"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "successRate(predicted, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.886"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testSet = dsets.MNIST(root='./data', train = False, transform=trans.ToTensor(), download=True)\n",
    "testloader = torch.utils.data.DataLoader(dataset=trainSet,  batch_size = 10000, shuffle = True)\n",
    "\n",
    "testData = iter(testloader)\n",
    "images, labels = testData.next()\n",
    "output = model(images.view(-1, 28 * 28))\n",
    "successRate(output,labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-3.4977,  4.9538, -0.1545,  0.4681, -1.7700,  0.5497,  0.0962,\n",
      "         -1.0039,  1.2567, -0.7102],\n",
      "        [-2.1123, -1.0816,  1.0725, -0.7383,  0.3273,  0.1948,  4.3535,\n",
      "         -1.8619,  0.1997, -0.7896],\n",
      "        [ 9.4955, -6.7109,  2.8008,  0.1442, -4.8314,  3.1091, -0.1016,\n",
      "         -1.7326,  1.2769, -3.4950]])\n",
      "tensor([ 1,  6,  0])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.905"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testSet = dsets.MNIST(root='./data', train = False, transform=trans.ToTensor(), download=True)\n",
    "testloader = torch.utils.data.DataLoader(dataset=trainSet, batch_size = batchSize, shuffle = True)\n",
    "\n",
    "testData = iter(testloader)\n",
    "images, labels = testData.next()\n",
    "output = model(images.view(-1, 28 * 28))\n",
    "#successRate(output,labels)\n",
    "predicted = model.forward(images.view(-1, 28 * 28))   \n",
    "print(predicted[0:3])\n",
    "print(labels[0:3])\n",
    "successRate(predicted,labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pred = [np.argmax(p.detach().numpy()) for p in output] \n",
    "incorrect = ([i for i, j in zip(pred, labels) if i != j])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 28, 28])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images[0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "784"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
